Fix --gpu_device mps falling back to CPU (fixes #1455)#1462
Open
gaoflow wants to merge 1 commit into
Open
Conversation
When the CLI is run with `--gpu_device mps`, the literal string "mps"
is passed through to `_use_gpu_torch` as the device index. It then
builds `torch.device("mps:mps")`, which is not a valid device string
and raises, so GPU detection fails and cellpose silently falls back to
CPU -- even though MPS is available (omitting the flag works).
Treat a non-integer gpu_number as a request for the default MPS device
(`torch.device("mps")`) instead of building `"mps:<gpu_number>"`.
Integer indices keep their existing `"mps:<n>"` behaviour.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes #1455. Running cellpose with
--gpu_device mpson Apple silicon silently uses the CPU, while omitting the flag correctly uses MPS:Root cause
__main__passesdevice=args.gpu_device(the string"mps") intoassign_device, which keeps it as the string"mps"and callsuse_gpu(gpu_number="mps")._use_gpu_torchthen treatsgpu_numberas a device index and builds:Both raise, so
_use_gpu_torchreturnsFalseandassign_devicefalls back to CPU. The default path works only becausegpu_devicedefaults to"0", which becomestorch.device("mps:0").Fix
In
_use_gpu_torch, treat a non-integergpu_number(i.e. the"mps"sentinel) as a request for the default MPS devicetorch.device("mps"), instead of building"mps:<gpu_number>". Integer indices keep their existing"mps:<n>"behaviour.Verification
Reproduced on Apple silicon (MPS available):
Added a regression test
test_assign_device_mps_stringintests/test_import.py(skipped when MPS is unavailable). It fails on the current code (assert gpu->False) and passes with this change. The integer/default device paths are unchanged.